package standalone

import (
	"github.com/DiracLee/dires-go/app/cmdline"
	"github.com/DiracLee/dires-go/app/connection"
	"github.com/DiracLee/dires-go/app/payload"
	"github.com/DiracLee/dires-go/ds/dict"
	"github.com/DiracLee/dires-go/ds/list"
	"github.com/DiracLee/dires-go/logger"
	"github.com/DiracLee/dires-go/syncx"
	"strconv"
)

// Command keywords used by the pub/sub handlers.
//
// KeyPublish is the command name passed to the argument-count error for
// PUBLISH. KeySubscribe and KeyUnsubscribe are the first element of the
// three-part replies built by makeMsg.
const (
	KeyPublish     = "publish"
	KeySubscribe   = "subscribe"
	KeyUnsubscribe = "unsubscribe"
)

// Pre-encoded byte values shared by the pub/sub handlers.
//
// MessageBytes is the first element of the frame pushed to subscribers when a
// message is published.
//
// UnsubscribeBytes is the reply written when UNSUBSCRIBE is issued and there
// is no channel to unsubscribe from: a three-element array holding the bulk
// string "unsubscribe", a null bulk string and the integer 0. Every element
// is terminated with CRLF, as the wire protocol requires.
var (
	MessageBytes     = []byte("message")
	UnsubscribeBytes = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\r\n:0\r\n")
)

// Hub is the publish/subscribe registry of a standalone server.
//
// subs holds, per channel name, the list of connections subscribed to that
// channel. locks is used by the handlers in this file to lock the channel
// names they are about to read or modify in subs.
type Hub struct {
	subs  dict.Dict // channel name -> list.List of connection.Connection
	locks syncx.LockBucket
}

// NewHub returns an empty Hub with its subscription dictionary and lock
// bucket initialised.
//
// NOTE(review): 4 and 16 are presumably shard/slot counts for the concurrent
// dict and the lock bucket — confirm against dict.NewConcurrent and
// syncx.NewLockBucket.
func NewHub() *Hub {
	hub := new(Hub)
	hub.subs = dict.NewConcurrent(4)
	hub.locks = syncx.NewLockBucket(16)
	return hub
}

// handelPublish handles PUBLISH: it writes message to every connection
// subscribed to channel.
//
// args must be exactly [channel, message]; otherwise an argument-count error
// payload is returned. The channel is locked in the hub for the duration of
// the call.
//
// The reply is an integer payload holding the length of the channel's
// subscriber list, or 0 when the hub has no entry for the channel or the
// entry is not a list.List. A failed write to a subscriber is logged and does
// not change the reply.
//
// The conn parameter (the publishing connection) is not used.
func handelPublish(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	if len(args) != 2 {
		return payload.NewArgNumErrPayload(KeyPublish)
	}
	hub := svr.hub
	channel, message := string(args[0]), args[1]
	hub.locks.Lock(channel)
	defer hub.locks.Unlock(channel)

	raw, ok := hub.subs.Get(channel)
	if !ok {
		return payload.NewIntPayload(0)
	}
	subscribers, ok := raw.(list.List)
	if !ok {
		return payload.NewIntPayload(0)
	}

	// The frame ["message", channel, message] is the same for every
	// subscriber, so it is encoded once instead of once per connection.
	frame := payload.NewMultiBulkPayload([][]byte{MessageBytes, []byte(channel), message}).Bytes()

	subscribers.ForEach(func(i int, v interface{}) bool {
		// Named sub rather than conn so the publisher parameter is not shadowed.
		sub, ok := v.(connection.Connection)
		if !ok {
			// Skip entries that are not connections and keep iterating.
			return true
		}
		if _, err := sub.Write(frame); err != nil {
			logger.Error("conn failed to write payload")
		}
		return true
	})
	return payload.NewIntPayload(int64(subscribers.Len()))
}

// handleSubscribe handles SUBSCRIBE: it subscribes conn to every channel
// named in args.
//
// With no arguments an argument-count error payload is returned. Otherwise
// all requested channels are locked in the hub, and for each channel that
// subscribe0 reports as newly subscribed a "subscribe" reply carrying the
// channel name and conn.SubsCount() is written directly to conn. Write
// failures are logged. The returned payload is the no-op payload because the
// replies have already been written.
func handleSubscribe(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	if len(args) == 0 {
		return payload.NewArgNumErrPayload(cmdline.CmdSubscribe)
	}
	hub := svr.hub
	channels := make([]string, 0, len(args))
	// Range over the values, not the indices: the one-variable form of range
	// yields the index, and string(int) produces a one-rune string rather
	// than the channel name.
	for _, arg := range args {
		channels = append(channels, string(arg))
	}
	hub.locks.Lock(channels...)
	defer hub.locks.Unlock(channels...)
	for _, channel := range channels {
		if subscribe0(hub, conn, channel) {
			_, err := conn.Write(makeMsg(KeySubscribe, channel, int64(conn.SubsCount())))
			if err != nil {
				logger.Error("[handleSubscribe] conn failed to write bytes")
			}
		}
	}
	return payload.NewNoPayload()
}

// handleUnsubscribe handles UNSUBSCRIBE for conn.
//
// When args is non-empty it names the channels to leave; otherwise the
// channels come from conn.GetChannels(). The target channels are locked in
// the hub for the rest of the call. If there is nothing to unsubscribe from,
// UnsubscribeBytes is written to conn. Otherwise, for each channel that
// unsubscribe0 reports as handled, an "unsubscribe" reply carrying the
// channel name and conn.SubsCount() is written. Write failures are logged.
// A NoPayload is always returned because replies are written directly.
func handleUnsubscribe(svr *Standalone, conn connection.Connection, args cmdline.CmdLine) payload.Payload {
	hub := svr.hub

	var targets []string
	if len(args) == 0 {
		targets = conn.GetChannels()
	} else {
		targets = make([]string, 0, len(args))
		for _, name := range args {
			targets = append(targets, string(name))
		}
	}

	hub.locks.Lock(targets...)
	defer hub.locks.Unlock(targets...)

	if len(targets) == 0 {
		if _, err := conn.Write(UnsubscribeBytes); err != nil {
			logger.Error("[handleUnsubscribe] conn failed to write bytes")
		}
		return &payload.NoPayload{}
	}

	for _, target := range targets {
		if !unsubscribe0(hub, conn, target) {
			continue
		}
		reply := makeMsg(KeyUnsubscribe, target, int64(conn.SubsCount()))
		if _, err := conn.Write(reply); err != nil {
			logger.Error("[handleUnsubscribe] conn failed to write bytes")
		}
	}
	return &payload.NoPayload{}
}

// UnsubscribeAll removes conn from every channel reported by
// conn.GetChannels(). The channels are locked in hub while they are
// processed. Unlike handleUnsubscribe, nothing is written to conn.
func UnsubscribeAll(hub *Hub, conn connection.Connection) {
	subscribed := conn.GetChannels()

	hub.locks.Lock(subscribed...)
	defer hub.locks.Unlock(subscribed...)

	for i := range subscribed {
		unsubscribe0(hub, conn, subscribed[i])
	}
}

// subscribe0 records that conn is subscribed to channel, both on conn itself
// and in hub.subs.
//
// It reports true when conn was appended to the channel's subscriber list and
// false when conn was already in it. It takes no lock itself; the callers in
// this file lock the channel in hub.locks before calling.
func subscribe0(hub *Hub, conn connection.Connection, channel string) bool {
	conn.Subscribe(channel)

	// add into hub.subs
	raw, found := hub.subs.Get(channel)
	subscribers, isList := raw.(list.List)
	if !found || !isList {
		// No entry yet, or an entry that is not a subscriber list: start from
		// a fresh list rather than calling methods on the zero value left by
		// a failed type assertion.
		subscribers = list.New()
		hub.subs.PutOrSet(channel, subscribers)
	}
	if subscribers.Contains(conn) {
		return false
	}
	subscribers.Append(conn)
	return true
}

// unsubscribe0 removes the subscription of conn to channel, both on conn
// itself and in hub.subs.
//
// It reports true when hub.subs holds a subscriber list for channel (whether
// or not conn was in it) and false otherwise. When the list becomes empty the
// channel's entry is removed from hub.subs. It takes no lock itself; the
// callers in this file lock the channel in hub.locks before calling.
func unsubscribe0(hub *Hub, conn connection.Connection, channel string) bool {
	conn.UnSubscribe(channel)

	// remove from hub.subs
	raw, found := hub.subs.Get(channel)
	if !found {
		return false
	}
	subscribers, isList := raw.(list.List)
	if !isList {
		// Not a subscriber list: there is nothing to remove from, and calling
		// methods on the zero value of a failed assertion must be avoided.
		return false
	}
	subscribers.RemoveValued(conn)

	if subscribers.Len() == 0 {
		// clean
		hub.subs.Remove(channel)
	}
	return true
}

// makeMsg encodes a three-element array reply: the bulk string t, the bulk
// string channel, and the integer code. Each bulk string is prefixed with its
// length in bytes.
func makeMsg(t string, channel string, code int64) []byte {
	kindLen := strconv.Itoa(len(t))
	channelLen := strconv.Itoa(len(channel))
	codeText := strconv.FormatInt(code, 10)

	msg := "*3\r\n"
	msg += "$" + kindLen + payload.CRLF + t + payload.CRLF
	msg += "$" + channelLen + payload.CRLF + channel + payload.CRLF
	msg += ":" + codeText + payload.CRLF
	return []byte(msg)
}
